Source code for hysop.backend.device.codegen.base.utils
# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
[docs]
class WriteOnceDict(dict):
    """Dictionary that refuses to overwrite existing keys while locked.

    Instances are locked on construction: assigning (``d[key] = val``) to a
    key that is already present raises ``RuntimeError``. Call :meth:`release`
    to allow overwrites and :meth:`lock` to forbid them again.

    Only item assignment is checked. Entries given to the constructor and the
    inherited ``dict`` mutators (``update``, ``setdefault``, ...) are not
    routed through the write-once check.
    """

    # Class-level fallback so __setitem__ works before lock() has run on the
    # instance: pickle restores dict items (via __setitem__) *before* it
    # restores the instance __dict__. lock()/release() shadow this per instance.
    allow_overwrites = False

    def __init__(self, *args, **kargs):
        """Build the dictionary like ``dict(*args, **kargs)``, then lock it.

        Positional arguments are accepted so that subclasses forwarding
        ``*args`` can be given positional initial contents.
        """
        super().__init__(*args, **kargs)
        self.lock()

    def __setitem__(self, key, val):
        """Store ``val`` under ``key``.

        Raises:
            RuntimeError: if overwrites are not allowed and ``key`` is
                already present.
        """
        if (not self.allow_overwrites) and (key in self):
            raise RuntimeError(
                f"Key {key} already in use for variable {str(self[key])}!"
            )
        super().__setitem__(key, val)

    def lock(self):
        """Forbid overwriting keys that are already present."""
        self.allow_overwrites = False

    def release(self):
        """Allow overwriting keys that are already present."""
        self.allow_overwrites = True

    def translate(self, key2key_dict):
        """Return a new (locked) :class:`WriteOnceDict` with renamed keys.

        For every key ``new_key`` of ``key2key_dict`` the result maps
        ``new_key`` to ``self[key2key_dict[new_key]]``. Any error raised by
        that lookup on ``self`` propagates.
        """
        out = WriteOnceDict()
        for new_key in key2key_dict.keys():
            out[new_key] = self[key2key_dict[new_key]]
        return out
[docs]
class ReadDefaultWriteOnceDict(WriteOnceDict):
    """Write-once dictionary whose ``d[key]`` lookup never raises ``KeyError``.

    Reading an absent key returns ``default_val`` instead, without inserting
    the key. Only subscript reads are affected: ``get``, ``in`` and iteration
    keep plain ``dict`` behaviour.
    """

    def __init__(self, default_val, *args, **kargs):
        super().__init__(*args, **kargs)
        # Value handed back by __getitem__ for any key that is not stored.
        self.default_val = default_val

    def __getitem__(self, key):
        if key in self:
            return super().__getitem__(key)
        return self.default_val
[docs]
class VarDict(WriteOnceDict):
    """Write-once mapping from names (``str``) to ``CodegenVariable`` objects.

    Type validation happens before the write-once check of the base class.
    """

    def __setitem__(self, key, val):
        """Validate ``key``/``val`` types, then store them.

        Raises:
            TypeError: if ``key`` is not a ``str`` or ``val`` is not a
                ``CodegenVariable``.
            RuntimeError: propagated from the base class when ``key`` is
                already present and overwrites are locked.
        """
        # Imported locally, presumably to avoid an import cycle with the
        # variables module -- TODO confirm.
        from hysop.backend.device.codegen.base.variables import CodegenVariable

        if not isinstance(key, str):
            raise TypeError("VarDict key should be a string!")
        if not isinstance(val, CodegenVariable):
            raise TypeError("VarDict value should inherit CodegenVariable!")
        super().__setitem__(key, val)
[docs]
class ArgDict(WriteOnceDict):
    """Ordered, write-once mapping of argument names to ``CodegenVariable``.

    Keys must be ``str`` and values ``CodegenVariable`` instances. The order
    in which names are first inserted is kept in ``arg_order``; it drives
    :meth:`items`, :meth:`build_args` and the name-suffix helpers.
    """

    def __init__(self, overloading_allowed=False, *args, **kargs):
        """Create an argument dictionary.

        Args:
            overloading_allowed: when false, :meth:`function_name_suffix`
                delegates to :meth:`codegen_name_suffix`, which also hashes
                the argument C types.
            *args, **kargs: optional initial entries, interpreted like
                ``dict(*args, **kargs)``.
        """
        super().__init__()
        self.arg_order = []
        self.overloading_allowed = overloading_allowed
        # Insert initial entries one by one: dict.__init__ would bypass
        # __setitem__, skipping the type checks and leaving the keys out of
        # arg_order (and thus out of items() and build_args()).
        for key, val in dict(*args, **kargs).items():
            self[key] = val

    def __setitem__(self, key, val):
        """Store ``val`` under ``key``, recording new keys in ``arg_order``.

        Raises:
            TypeError: if ``key`` is not a ``str`` or ``val`` is not a
                ``CodegenVariable``.
            RuntimeError: propagated from the base class when ``key`` is
                already present and overwrites are locked.
        """
        # Imported locally, presumably to avoid an import cycle with the
        # variables module -- TODO confirm.
        from hysop.backend.device.codegen.base.variables import CodegenVariable

        if not isinstance(key, str):
            raise TypeError("ArgDict key should be a string!")
        elif not isinstance(val, CodegenVariable):
            raise TypeError("ArgDict value should inherit CodegenVariable!")
        is_new = key not in self
        super().__setitem__(key, val)
        # Appended only after the base class accepted the write.
        if is_new:
            self.arg_order.append(key)

    def items(self):
        """Return an iterator of ``(name, variable)`` pairs in ``arg_order``."""
        return iter([(argname, self[argname]) for argname in self.arg_order])

    def update(self, other):
        """Insert every ``(key, val)`` of ``other.items()``; return ``self``."""
        for key, val in other.items():
            self[key] = val
        return self

    def build_args(self):
        """Split the arguments into prototype/implementation args and constants.

        Arguments are visited in ``arg_order``:

        * ``var.symbolic_mode`` set and ``var.known()`` true: the variable
          goes to ``constant_args``;
        * otherwise, if ``var.is_symbolic()``: ``var.argument(impl=False)``
          goes to the prototype list and ``var.argument(impl=True)`` to the
          implementation list;
        * anything else is asserted to be known and non-symbolic and is not
          emitted.

        One trailing newline is stripped from the last prototype argument and
        from the last implementation argument, if present.

        Returns:
            tuple: ``(function_proto_args, function_impl_args, constant_args)``.
        """
        function_proto_args = []
        function_impl_args = []
        constant_args = []
        for varname in self.arg_order:
            var = self[varname]
            if var.symbolic_mode and var.known():
                constant_args.append(var)
            elif var.is_symbolic():
                function_proto_args.append(var.argument(impl=False))
                function_impl_args.append(var.argument(impl=True))
            else:
                assert var.known()
                assert var.symbolic_mode == False
                assert var.is_symbolic() == False
        # Each list's last element is guarded independently, so an empty last
        # argument can no longer raise IndexError.
        for args_list in (function_proto_args, function_impl_args):
            if args_list and args_list[-1] and args_list[-1][-1] == "\n":
                args_list[-1] = args_list[-1][:-1]
        return function_proto_args, function_impl_args, constant_args

    def function_name_suffix(self, return_type, known_args):
        """Return a short ``"_xxxx"`` hash suffix that specializes a function name.

        If ``overloading_allowed`` is false this returns
        :meth:`codegen_name_suffix`. Otherwise the hash covers the return type
        and the values of non-symbolic arguments and of symbolic arguments
        listed in ``known_args``; argument C types are not included.

        Args:
            return_type: formatted into the hashed string.
            known_args: optional mapping ``name -> value`` giving values for
                symbolic arguments.
        """
        if not self.overloading_allowed:
            return self.codegen_name_suffix(return_type, known_args)
        parts = [f"({return_type})_"]
        for varname in self.arg_order:
            var = self[varname]
            if not var.is_symbolic():
                parts.append(f"_{var.name}={var.sval()}")
            elif known_args and (varname in known_args):
                tmp = var.copy()
                tmp.set_value(known_args[varname])
                parts.append(f"_{var.name}={tmp.sval()}")
        # The string always starts with the return type, so it is never empty.
        return "_" + self.hash("".join(parts))

    def codegen_name_suffix(self, return_type, known_args):
        """Return a ``"_xxxx"`` hash suffix that also encodes argument C types.

        Handles type-based function overloading: every argument contributes
        its ``ctype`` and name, plus its value when it is non-symbolic or
        listed in ``known_args``.

        Args:
            return_type: formatted into the hashed string.
            known_args: optional mapping ``name -> value`` giving values for
                symbolic arguments.
        """
        parts = [f"({return_type})_"]
        for varname in self.arg_order:
            var = self[varname]
            if not var.is_symbolic():
                parts.append(f"_({var.ctype}){var.name}={var.sval()}")
            elif known_args and (varname in known_args):
                tmp = var.copy()
                tmp.set_value(known_args[varname])
                parts.append(f"_({var.ctype}){var.name}={tmp.sval()}")
            else:
                parts.append(f"_({var.ctype}){var.name}")
        # The string always starts with the return type, so it is never empty.
        return "_" + self.hash("".join(parts))

    def hash(self, string):
        """Return the first 4 hex digits of the SHA-1 digest of ``string``.

        This keeps 16 bits of hash. By the birthday bound,
        sqrt(16**nb) = 2**(2*nb) with nb = 4 gives about 256 functions that
        can share the same base name before collisions become likely.
        """
        return hashlib.sha1(string.encode("utf-8")).hexdigest()[:4]
[docs]
class SortedDict(dict):
    """Dictionary whose keys/values/items/iteration are ordered by key name.

    Keys are sorted with :meth:`_key`: the key's ``name`` attribute when it
    has one, ``str(key)`` otherwise. ``sorted`` is stable, so keys with equal
    sort keys keep insertion order. Storage and lookup are plain ``dict``.
    """

    @classmethod
    def _key(cls, k):
        """Return the sort key of ``k``: ``k.name`` if present, else ``str(k)``."""
        if hasattr(k, "name"):
            return k.name
        return str(k)

    def keys(self):
        """Return the keys as a sorted ``list``."""
        return sorted(super().keys(), key=self._key)

    def iterkeys(self):
        """Return an iterator over the sorted keys."""
        return iter(sorted(super().keys(), key=self._key))

    def values(self):
        """Return a ``list`` of values, ordered by their sorted keys."""
        return [self[k] for k in self.keys()]

    def itervalues(self):
        """Return an iterator over values, ordered by their sorted keys."""
        return iter(self[k] for k in self.keys())

    # A previous, shadowed duplicate ``items`` returning a tuple was removed;
    # only this iterator-returning version was ever in effect.
    def items(self):
        """Return an iterator of ``(key, value)`` pairs in sorted key order."""
        return iter((k, self[k]) for k in self.keys())

    def __iter__(self):
        """Iterate over the keys in sorted order."""
        return self.iterkeys()